{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ],
      "metadata": {
        "id": "SGjEjVxwudf2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Apache Beam RunInference with Hugging Face\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_huggingface.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_huggingface.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ],
      "metadata": {
        "id": "BJ0mTOhFucGg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "This notebook shows how to use models from [Hugging Face](https://huggingface.co/) and [Hugging Face pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines) in Apache Beam pipelines that use the [RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) transform.\n",
        "\n",
        "Apache Beam has built-in support for Hugging Face through three model handlers:\n",
        "\n",
        "\n",
        "\n",
        "*   Use the [`HuggingFacePipelineModelHandler`](https://github.com/apache/beam/blob/926774dd02be5eacbe899ee5eceab23afb30abca/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L567) model handler to run inference with [Hugging Face pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines#pipelines).\n",
        "*   Use the [`HuggingFaceModelHandlerKeyedTensor`](https://github.com/apache/beam/blob/926774dd02be5eacbe899ee5eceab23afb30abca/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L208) model handler to run inference with models that use keyed tensors as inputs. For example, you might use this model handler with language modeling tasks.\n",
        "*   Use the [`HuggingFaceModelHandlerTensor`](https://github.com/apache/beam/blob/926774dd02be5eacbe899ee5eceab23afb30abca/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L392) model handler to run inference with models that use tensor inputs, such as `tf.Tensor` or `torch.Tensor`.\n",
        "\n",
        "\n",
        "For more information about using RunInference, see [Get started with AI/ML pipelines](https://beam.apache.org/documentation/ml/overview/) in the Apache Beam documentation."
      ],
      "metadata": {
        "id": "GBloorZevCXC"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Install dependencies"
      ],
      "metadata": {
        "id": "xpylrB7_2Jzk"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Install both Apache Beam and the required dependencies for Hugging Face."
      ],
      "metadata": {
        "id": "IBQLg8on2S40"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yrqNSBB3qsI1"
      },
      "outputs": [],
      "source": [
        "!pip install torch --quiet\n",
        "!pip install tensorflow --quiet\n",
        "!pip install transformers==4.44.2 --quiet\n",
        "!pip install 'apache-beam[gcp]>=2.50' --quiet"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from typing import Dict\n",
        "from typing import Iterable\n",
        "from typing import Tuple\n",
        "\n",
        "import tensorflow as tf\n",
        "import torch\n",
        "from transformers import AutoTokenizer\n",
        "from transformers import TFAutoModelForMaskedLM\n",
        "\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference.base import KeyedModelHandler\n",
        "from apache_beam.ml.inference.base import PredictionResult\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.ml.inference.huggingface_inference import HuggingFacePipelineModelHandler\n",
        "from apache_beam.ml.inference.huggingface_inference import HuggingFaceModelHandlerKeyedTensor\n",
        "from apache_beam.ml.inference.huggingface_inference import HuggingFaceModelHandlerTensor\n",
        "from apache_beam.ml.inference.huggingface_inference import PipelineTask\n"
      ],
      "metadata": {
        "id": "BIDZLGFRrAmF"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Use RunInference with Hugging Face pipelines"
      ],
      "metadata": {
        "id": "OQ1wv6xk3UeV"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "You can use [Hugging Face pipelines](https://huggingface.co/docs/transformers/main_classes/pipelines#pipelines) with `RunInference` by using the `HuggingFacePipelineModelHandler` model handler. As with Hugging Face pipelines, instantiating the model handler requires either the pipeline `task` or the `model` that defines the task. To pass optional arguments for loading the pipeline, use `load_pipeline_args`. To pass optional arguments for inference, use `inference_args`.\n",
        "\n",
        "\n",
        "\n",
        "You can define the pipeline task in one of the following two ways:\n",
        "\n",
        "\n",
        "\n",
        "*   In the form of a string, for example `\"translation\"`. This option is similar to how the pipeline task is defined when using Hugging Face.\n",
        "*   In the form of a [`PipelineTask`](https://github.com/apache/beam/blob/ac936b0b89a92d836af59f3fc04f5733ad6819b3/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L75) enum object defined in Apache Beam, such as `PipelineTask.Translation`.\n"
      ],
      "metadata": {
        "id": "hWZQ49Pt3ojg"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Create a model handler\n",
        "\n",
        "This example demonstrates a task that translates text from English to Spanish."
      ],
      "metadata": {
        "id": "pVVg9RfET86L"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model_handler = HuggingFacePipelineModelHandler(\n",
        "    task=PipelineTask.Translation_XX_to_YY,\n",
        "    model=\"google/flan-t5-small\",\n",
        "    load_pipeline_args={'framework': 'pt'},\n",
        "    inference_args={'max_length': 200}\n",
        ")"
      ],
      "metadata": {
        "id": "aF_BDPXk3UG4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Define the input examples\n",
        "\n",
        "Use this code to define the input examples."
      ],
      "metadata": {
        "id": "lxTImFGJUBIw"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "text = [\"translate English to Spanish: How are you doing?\",\n",
        "        \"translate English to Spanish: This is the Apache Beam project.\"]"
      ],
      "metadata": {
        "id": "POAuIFS_UDgE"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Postprocess the results\n",
        "\n",
        "The output from the `RunInference` transform is a `PredictionResult` object. Use that output to extract inferences, and then format and print the results."
      ],
      "metadata": {
        "id": "Ay6-7DD5TZLn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class FormatOutput(beam.DoFn):\n",
        "  \"\"\"\n",
        "  Extract the results from PredictionResult and print the results.\n",
        "  \"\"\"\n",
        "  def process(self, element):\n",
        "    example = element.example\n",
        "    translated_text = element.inference[0]['translation_text']\n",
        "    print(f'Example: {example}')\n",
        "    print(f'Translated text: {translated_text}')\n",
        "    print('-' * 80)\n"
      ],
      "metadata": {
        "id": "74I3U1JsrG0R"
      },
      "execution_count": null,
      "outputs": []
    },
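    {
      "cell_type": "markdown",
      "source": [
        "To make the indexing in `FormatOutput` concrete, the following minimal sketch builds a stand-in for `PredictionResult` and extracts the translation the same way. `FakePredictionResult` and `extract_translation` are hypothetical names used only for illustration; the real class is `apache_beam.ml.inference.base.PredictionResult`. The sample inference value is made up and assumes the list-of-dictionaries shape that `element.inference[0]['translation_text']` implies.\n",
        "\n",
        "```python\n",
        "from typing import Any, NamedTuple\n",
        "\n",
        "\n",
        "# Hypothetical stand-in that mirrors the two fields FormatOutput reads.\n",
        "class FakePredictionResult(NamedTuple):\n",
        "  example: Any\n",
        "  inference: Any\n",
        "\n",
        "\n",
        "def extract_translation(result):\n",
        "  # Assumption: the inference is a list whose first item is a dict\n",
        "  # that holds the translated string under 'translation_text'.\n",
        "  return result.inference[0]['translation_text']\n",
        "\n",
        "\n",
        "# Made-up sample data, not real model output.\n",
        "sample = FakePredictionResult(\n",
        "    example='translate English to Spanish: Hello',\n",
        "    inference=[{'translation_text': 'Hola'}],\n",
        ")\n",
        "print(extract_translation(sample))\n",
        "```\n",
        "\n",
        "Keeping the extraction in a small function like this can make the formatting logic easier to test without loading a model."
      ],
      "metadata": {}
    },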
    {
      "cell_type": "markdown",
      "source": [
        "### Run the pipeline\n",
        "\n",
        "Use the following code to run the pipeline."
      ],
      "metadata": {
        "id": "Ve2cpHZ_UH0o"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with beam.Pipeline() as beam_pipeline:\n",
        "  examples = (\n",
        "      beam_pipeline\n",
        "      | \"CreateExamples\" >> beam.Create(text)\n",
        "  )\n",
        "  inferences = (\n",
        "      examples\n",
        "      | \"RunInference\" >> RunInference(model_handler)\n",
        "      | \"Print\" >> beam.ParDo(FormatOutput())\n",
        "  )"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "_xStwO3qubqr",
        "outputId": "5aeef601-c3e5-4b0f-e982-183ff36dc56e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Example: translate English to Spanish: How are you doing?\n",
            "Translated text: Cómo está acerca?\n",
            "--------------------------------------------------------------------------------\n",
            "Example: translate English to Spanish: This is the Apache Beam project.\n",
            "Translated text: Esto es el proyecto Apache Beam.\n",
            "--------------------------------------------------------------------------------\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Try with a different model\n",
        "One of the best parts of using RunInference is how easy it is to swap in different models. For example, if we wanted to use a larger model like DeepSeek-R1-Distill-Llama-8B outside of Colab (which has very tight memory constraints and limited GPU access), all we need to change is our ModelHandler:\n",
        "\n",
        "```\n",
        "model_handler = HuggingFacePipelineModelHandler(\n",
        "    task=PipelineTask.Translation_XX_to_YY,\n",
        "    model = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\",\n",
        "    load_pipeline_args={'framework': 'pt'},\n",
        "    inference_args={'max_length': 400}\n",
        ")\n",
        "```\n",
        "\n",
        "We can then run the exact same pipeline code and Beam will take care of the rest."
      ],
      "metadata": {
        "id": "LWNM81ivoZcF"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## RunInference with a pretrained model from Hugging Face Hub\n"
      ],
      "metadata": {
        "id": "KJEsPkXnUS5y"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "To use pretrained models directly from [Hugging Face Hub](https://huggingface.co/docs/hub/models), use either the `HuggingFaceModelHandlerTensor` model handler or the `HuggingFaceModelHandlerKeyedTensor` model handler. Which model handler you use depends on your input type:\n",
        "\n",
        "\n",
        "*   Use `HuggingFaceModelHandlerKeyedTensor` to run inference with models that use keyed tensors as inputs.\n",
        "*   Use `HuggingFaceModelHandlerTensor` to run inference with models that use tensor inputs, such as `tf.Tensor` or `torch.Tensor`.\n",
        "\n",
        "When you construct your pipeline, you might also need to use the following items:\n",
        "\n",
        "\n",
        "*   Use the `load_model_args` argument to provide optional arguments for loading the model.\n",
        "*   Use the `inference_args` argument to pass optional arguments for inference.\n",
        "*   For TensorFlow models, specify `framework='tf'`.\n",
        "*   For PyTorch models, specify `framework='pt'`.\n",
        "\n",
        "\n",
        "\n",
        "The following language modeling task predicts the masked word in a sentence."
      ],
      "metadata": {
        "id": "mDcpG78tWcBN"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Create a model handler\n",
        "\n",
        "This example shows a masked language modeling task. These models take keyed tensors as inputs."
      ],
      "metadata": {
        "id": "dU6NDE4DaRuF"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "model_handler = HuggingFaceModelHandlerKeyedTensor(\n",
        "    model_uri=\"stevhliu/my_awesome_eli5_mlm_model\",\n",
        "    model_class=TFAutoModelForMaskedLM,\n",
        "    framework='tf',\n",
        "    load_model_args={'from_pt': True},\n",
        "    max_batch_size=1\n",
        ")"
      ],
      "metadata": {
        "id": "Zx1ep1UXWYrq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Define the input examples\n",
        "\n",
        "Use this code to define the input examples."
      ],
      "metadata": {
        "id": "D18eQZfgcIM6"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "text = ['The capital of France is Paris .',\n",
        "    'It is raining cats and dogs .',\n",
        "    'He looked up and saw the sun and stars .',\n",
        "    'Today is Monday and tomorrow is Tuesday .',\n",
        "    'There are 5 coconuts on this palm tree .']"
      ],
      "metadata": {
        "id": "vWI_A6VrcH-m"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "### Preprocess the input\n",
        "\n",
        "Replace the last word of each input sentence with a `<mask>` token. Then, tokenize the input for inference."
      ],
      "metadata": {
        "id": "-62nvMSbeNBy"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "def add_mask_to_last_word(text: str) -> Tuple[str, str]:\n",
        "  \"\"\"Replace the last word of the sentence with <mask> and return\n",
        "  the original sentence and the masked sentence.\"\"\"\n",
        "  text_list = text.split()\n",
        "  masked = ' '.join(text_list[:-2] + ['<mask>' + text_list[-1]])\n",
        "  return text, masked\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"stevhliu/my_awesome_eli5_mlm_model\")\n",
        "\n",
        "def tokenize_sentence(\n",
        "    text_and_mask: Tuple[str, str],\n",
        "    tokenizer) -> Tuple[str, Dict[str, tf.Tensor]]:\n",
        "  \"\"\"Convert string examples to tensors.\"\"\"\n",
        "  text, masked_text = text_and_mask\n",
        "  tokenized_sentence = tokenizer.encode_plus(\n",
        "      masked_text, return_tensors=\"tf\")\n",
        "\n",
        "  # Workaround to manually remove batch dim until we have the feature to\n",
        "  # add optional batching flag.\n",
        "  # TODO(https://github.com/apache/beam/issues/21863): Remove when optional\n",
        "  # batching flag added\n",
        "  return text, {\n",
        "      k: tf.squeeze(v)\n",
        "      for k, v in dict(tokenized_sentence).items()\n",
        "  }"
      ],
      "metadata": {
        "id": "d6TXVfWhWRzz"
      },
      "execution_count": null,
      "outputs": []
    },
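    {
      "cell_type": "markdown",
      "source": [
        "The slicing in `add_mask_to_last_word` can look surprising at first, so here is a small standalone sketch of what it does to one of the example sentences. The function body is repeated so that the snippet runs on its own. It assumes, as in the examples above, that each sentence ends with a punctuation mark separated from the last word by a space.\n",
        "\n",
        "```python\n",
        "def add_mask_to_last_word(text):\n",
        "  # Same logic as the preprocessing function in this notebook.\n",
        "  text_list = text.split()\n",
        "  # text_list[:-2] drops both the last word and the trailing punctuation\n",
        "  # token; '<mask>' is then glued directly onto that punctuation token.\n",
        "  masked = ' '.join(text_list[:-2] + ['<mask>' + text_list[-1]])\n",
        "  return text, masked\n",
        "\n",
        "\n",
        "original, masked = add_mask_to_last_word('The capital of France is Paris .')\n",
        "print(original)\n",
        "print(masked)\n",
        "```\n",
        "\n",
        "The original sentence is returned unchanged so that it can serve as the key that travels with the tokenized tensors through `RunInference`."
      ],
      "metadata": {}
    },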
    {
      "cell_type": "markdown",
      "source": [
        "### Postprocess the results\n",
        "\n",
        "Extract the result from the `PredictionResult` object. Then, format the output to print the actual sentence and the word predicted for the last word in the sentence."
      ],
      "metadata": {
        "id": "KnLtuYyTfC-g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class PostProcessor(beam.DoFn):\n",
        "  \"\"\"Processes the PredictionResult to get the predicted word.\n",
        "\n",
        "  The logits are the output of the masked language model. To get the word with the highest\n",
        "  probability of being the masked word, take the argmax.\n",
        "  \"\"\"\n",
        "  def __init__(self, tokenizer):\n",
        "    super().__init__()\n",
        "    self.tokenizer = tokenizer\n",
        "\n",
        "  def process(self, element: Tuple[str, PredictionResult]) -> Iterable[str]:\n",
        "    text, prediction_result = element\n",
        "    inputs = prediction_result.example\n",
        "    logits = prediction_result.inference['logits']\n",
        "    mask_token_index = tf.where(inputs[\"input_ids\"] == self.tokenizer.mask_token_id)[0]\n",
        "    predicted_token_id = tf.math.argmax(logits[mask_token_index[0]], axis=-1)\n",
        "    decoded_word = self.tokenizer.decode(predicted_token_id)\n",
        "    print(f\"Actual Sentence: {text}\\nPredicted last word: {decoded_word}\")\n",
        "    print('-' * 80)"
      ],
      "metadata": {
        "id": "DnWlNV1kelnq"
      },
      "execution_count": null,
      "outputs": []
    },
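    {
      "cell_type": "markdown",
      "source": [
        "The argmax step in `PostProcessor` can be sketched without TensorFlow by using plain Python lists. Everything in this snippet is hypothetical and chosen for illustration: the five-token vocabulary, the token IDs, and the logit values are made up, and a real tokenizer and model produce far larger arrays.\n",
        "\n",
        "```python\n",
        "# Hypothetical five-token vocabulary, for illustration only.\n",
        "vocab = ['<s>', '</s>', '<mask>', 'Paris', 'London']\n",
        "mask_token_id = vocab.index('<mask>')\n",
        "\n",
        "# Made-up token IDs for one sentence, with one row of logits per position.\n",
        "input_ids = [0, 3, 2, 1]\n",
        "logits = [\n",
        "    [9.0, 0.1, 0.1, 0.2, 0.3],\n",
        "    [0.1, 0.2, 0.1, 7.5, 1.0],\n",
        "    [0.1, 0.1, 0.2, 6.0, 2.5],\n",
        "    [0.2, 8.0, 0.1, 0.3, 0.1],\n",
        "]\n",
        "\n",
        "# Locate the mask position, then take the argmax over the vocabulary axis.\n",
        "mask_position = input_ids.index(mask_token_id)\n",
        "row = logits[mask_position]\n",
        "predicted_token_id = max(range(len(row)), key=row.__getitem__)\n",
        "print(vocab[predicted_token_id])\n",
        "```\n",
        "\n",
        "In the notebook code, `tf.where` plays the role of the index lookup and `tf.math.argmax` plays the role of the `max` call."
      ],
      "metadata": {}
    },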
    {
      "cell_type": "markdown",
      "source": [
        "### Run the pipeline\n",
        "\n",
        "Use the following code to run the pipeline."
      ],
      "metadata": {
        "id": "scepcVUpgD63"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "with beam.Pipeline() as beam_pipeline:\n",
        "  tokenized_examples = (\n",
        "      beam_pipeline\n",
        "      | \"CreateExamples\" >> beam.Create(text)\n",
        "      | 'AddMask' >> beam.Map(add_mask_to_last_word)\n",
        "      | 'TokenizeSentence' >>\n",
        "      beam.Map(lambda x: tokenize_sentence(x, tokenizer)))\n",
        "\n",
        "  result = (\n",
        "      tokenized_examples\n",
        "      | \"RunInference\" >> RunInference(KeyedModelHandler(model_handler))\n",
        "      | \"PostProcess\" >> beam.ParDo(PostProcessor(tokenizer))\n",
        "  )"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IEPrQGEWgBCo",
        "outputId": "218cc1f4-2613-4bf1-9666-782df020536b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Actual Sentence: The capital of France is Paris .\n",
            "Predicted last word:  Paris\n",
            "--------------------------------------------------------------------------------\n",
            "Actual Sentence: It is raining cats and dogs .\n",
            "Predicted last word:  dogs\n",
            "--------------------------------------------------------------------------------\n",
            "Actual Sentence: He looked up and saw the sun and stars .\n",
            "Predicted last word:  stars\n",
            "--------------------------------------------------------------------------------\n",
            "Actual Sentence: Today is Monday and tomorrow is Tuesday .\n",
            "Predicted last word:  Tuesday\n",
            "--------------------------------------------------------------------------------\n",
            "Actual Sentence: There are 5 coconuts on this palm tree .\n",
            "Predicted last word:  tree\n",
            "--------------------------------------------------------------------------------\n"
          ]
        }
      ]
    }
  ]
}
